import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
from PIL import Image
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import re

mnist = input_data.read_data_sets("./tensorflow/MNIST_data", one_hot=True)

'''
x不是一个特定的值，而是一个占位符placeholder，我们在TensorFlow运行计算时输入这个值。
我们希望能够输入任意数量的MNIST图像，每一张图展平成784维的向量。我们用2维的浮点数张量来表示这些图，
这个张量的形状是[None，784 ]。（这里的None表示此张量的第一个维度可以是任何长度的。）
'''
x = tf.placeholder(tf.float32, [None, 784])

'''
用全为零的张量来初始化W和b
'''
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
'''
训练模型
用tf.matmul(​​X，W)表示x乘以W，对应之前等式里面的
'''
y = tf.nn.softmax(tf.matmul(x,W) + b)

# #评估指标
# y_ = tf.placeholder("float", [None,10])
# #交叉熵计算
# cross_entropy = -tf.reduce_sum(y_*tf.log(y))
# '''
# 用梯度下降算法（gradient descent algorithm）以0.01的学习速率最小化交叉熵
# '''
# train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
# # Add an op to initialize the variables.
# init = tf.initialize_all_variables()

#saver model
model_saver = tf.train.Saver()

# sess = tf.Session()
# sess.run(init)

# for i in range(1000):
#     batch_xs, batch_ys = mnist.train.next_batch(100)
#     sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
#
# correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
# accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
# print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))
#
# # create dir for model saver
# model_path = "./tensorflow/tmp/model.ckpt"
# save_path =model_saver.save(sess, model_path)
# print("Model saved in file: ", save_path)

# #图片读取
# def readImages(filename):
#     label=1
#     # imageList = []
#     with tf.Session() as sess:
#         # for filename in filenameList:
#             # print(filename)
#             # 读取图像的原始数据
#             image_raw_data = tf.gfile.FastGFile(filename,'rb').read()  # 必须是 ‘rb’ 模式打开，否则会报错
#             # 将图像使用 jpeg 的格式解码从而得到图像对应的三维矩阵
#             # tf.image.decode_jpeg 函数对 png 格式的图像进行解码。解码之后的结果为一个张量，
#             # 在使用它的取值之前需要明确调用运行的过程。
#             print(filename)
#             img_data = tf.image.decode_jpeg(image_raw_data)
#             # arr = np.reshape(img_data.eval(sess), [-1])  # 多维矩阵转一维矩阵
#             arr = sess.run(tf.reshape(img_data.eval(), [-1]))
#             # imageList.append(arr)
#             print(tf.shape(arr))
#     return np.array(arr),label

def image_to_array(path):
    im = Image.open(path)
    # w, h = im.size
    # r, g, b = im.split()  # rgb通道分离
    r_arr = np.array(im).reshape(-1,28*28)
    # g_arr = np.array(g).reshape(28)
    # b_arr = np.array(b).reshape(28)
    # plt.imshow(im)
    # plt.show()
    if(np.shape(r_arr)[0]>1):
        r_arr = r_arr[0].reshape(-1,28*28)
    return r_arr,1
#Launch the gtrph
with tf.Session() as sess:
    #create dir for model saver
    model_path = "./tensorflow/tmp/model.ckpt"
    model_saver.restore(sess,model_path)

    # img=mnist.test.images[20].reshape(-1,784)
    # img_label=sess.run(tf.argmax(mnist.test.labels[20]))
    image_path  = "./tensorflow/mnist_digits_images/2.jpg"
    img,img_label = image_to_array(image_path)
    print(np.shape(img))
    ret=sess.run(y,feed_dict={x:img})
    num_pred=sess.run(tf.argmax(ret,1))

    print("预测值:%d\n" % num_pred)
    print("真实值:",img_label)
    print("模型恢复成功")
